import os
import random
import ast
import torch
import numpy as np
from datasets import load_dataset
from torch.utils.data import Subset

def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator; when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Value handed to each library's seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing more to do on machines without CUDA.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Synthetic Data Generation Functions
def generate_synthetic_example(seq_length=10, K=5, noise_range=100, idx=None):
    """Build one synthetic "predict the next signal" example.

    A start value ``s0`` in ``[0, K - 1]`` and a non-zero step ``d`` in
    ``[1, K - 1]`` are drawn at random.  Token ``t`` carries the signal
    ``(s0 + t * d) % K`` together with an independent random noise value in
    ``[0, noise_range - 1]`` and is rendered as ``"S<signal>_N<noise>"``.

    Args:
        seq_length: Total number of tokens generated (context plus the final,
            held-out token).  Must be at least 1.
        K: Modulus of the arithmetic progression.  Must be at least 2 so that
            a non-zero step exists.
        noise_range: Number of distinct noise values.  Must be at least 1.
        idx: Optional identifier; when not ``None`` it is stored, converted to
            ``str``, under the ``"id"`` key.

    Returns:
        A dict with ``"context"`` (all tokens except the last, joined by single
        spaces), ``"target"`` (the last token's signal as a string) and,
        if ``idx`` was given, ``"id"``.

    Raises:
        ValueError: If ``seq_length < 1``, ``K < 2`` or ``noise_range < 1``.
    """
    # Validate before touching the RNG so that a bad call neither consumes
    # random state nor fails later with an opaque IndexError / "empty range".
    if seq_length < 1:
        raise ValueError(f"seq_length must be at least 1, got {seq_length}")
    if K < 2:
        raise ValueError(f"K must be at least 2 to allow a non-zero difference, got {K}")
    if noise_range < 1:
        raise ValueError(f"noise_range must be at least 1, got {noise_range}")

    # Choose s0 and a non-zero common difference d.
    s0 = random.randint(0, K - 1)
    d = random.randint(1, K - 1)

    signals = []
    tokens = []
    for t in range(seq_length):
        # t-th signal of the arithmetic progression modulo K.
        s_t = (s0 + t * d) % K
        signals.append(s_t)

        # One noise draw per token, in token order.
        noise = random.randint(0, noise_range - 1)
        tokens.append(f"S{s_t}_N{noise}")

    example = {
        # Context is every token except the last one.
        "context": " ".join(tokens[:-1]),
        # Target is the signal (not the noise) of the final token.
        "target": str(signals[-1]),
    }
    if idx is not None:
        example["id"] = str(idx)
    return example

def generate_synthetic_dataset(num_examples, seq_length=10, K=5, noise_range=100):
    """Generate ``num_examples`` synthetic examples with ids ``0 .. num_examples - 1``.

    Each example comes from ``generate_synthetic_example`` called with the
    given ``seq_length``, ``K`` and ``noise_range``; its position in the
    returned list is passed as ``idx``.

    Returns:
        A list of example dicts, in generation order.
    """
    dataset = []
    for example_id in range(num_examples):
        example = generate_synthetic_example(seq_length, K, noise_range, idx=example_id)
        dataset.append(example)
    return dataset

def _clutrr_query_text(query):
    """Turn a CLUTRR ``query`` field into the question text that follows the story."""
    fallback = "What is the relationship? Answer:"
    if query is None:
        return fallback
    if not (isinstance(query, str) and query.strip().startswith("(")):
        return f"What is the relationship between {query}? Answer:"
    # The query looks like a tuple literal: parse it and name both entities.
    try:
        parsed_query = ast.literal_eval(query)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        # Best effort: report the problem and fall back to the generic question.
        print(f"Error parsing query '{query}': {e}")
        return fallback
    if isinstance(parsed_query, (list, tuple)) and len(parsed_query) >= 2:
        return f"What is the relationship between {parsed_query[0]} and {parsed_query[1]}? Answer:"
    return fallback


def _synthetic_pair(K, idx):
    """Generate one synthetic example and return its ``(prompt, combined)`` strings."""
    ex = generate_synthetic_example(seq_length=10, K=K, noise_range=100, idx=idx)
    # The prompt carries the trailing space, so combined == prompt + target.
    # NOTE: for GPT the original author used the variant without the trailing
    # space on the prompt (combined = prompt + " " + target) instead.
    prompt = ex["context"] + " "
    return prompt, prompt + ex["target"]


def load_data(input_file=None, sample_count=200, dataset_split="train", dataset_type="clutrr", seed=42):
    """Load prompts and prompt-plus-answer texts from a file or a dataset.

    Args:
        input_file: Optional path to a text file with one prompt per line.
            When it exists it takes precedence over ``dataset_type``; blank
            lines are skipped and each line is used both as prompt and as
            combined text.
        sample_count: Maximum number of samples to return.  For the synthetic
            ``"test"``, ``"test_id"`` and ``"test_ood"`` splits each generated
            part has ``sample_count // 2`` examples.  The synthetic
            ``"validation"`` split ignores this value and always produces
            2000 examples.
        dataset_split: Split name.  For CLUTRR, ``"test_id"`` / ``"test_ood"``
            load the split named before the underscore and then filter by
            task name.
        dataset_type: One of ``"clutrr"``, ``"synthetic"`` or ``"ecqa"``
            (case-insensitive); only consulted when no input file is used.
        seed: Seed used for dataset shuffling and synthetic generation.

    Returns:
        A tuple ``(prompts, combined_texts)`` of equally long lists of strings.

    Raises:
        ValueError: If ``dataset_type`` is not a supported type.
    """
    prompts = []
    combined_texts = []

    if input_file and os.path.exists(input_file):
        with open(input_file, 'r') as f:
            prompts = [line for line in (raw.strip() for raw in f) if line]
        combined_texts = list(prompts)
    else:
        kind = dataset_type.lower()

        if kind == "clutrr":
            # "test_id" / "test_ood" both read the "test" split; the suffix
            # only selects which tasks are kept below.
            dataset = load_dataset("CLUTRR/v1", "gen_train23_test2to10", split=dataset_split.split("_")[0])
            dataset = dataset.shuffle(seed=seed)
            simple_tasks = ("task_1.2", "task_1.3")
            for example in dataset:
                task_name = example.get("task_name", "")
                if dataset_split == "test_id" and task_name not in simple_tasks:
                    continue  # only include simple tasks for test_id
                if dataset_split == "test_ood" and task_name in simple_tasks:
                    continue  # exclude simple tasks for test_ood
                query_str = _clutrr_query_text(example.get("query", None))
                prompt = f"Story: {example['story']}\nQuery: {query_str}"
                prompts.append(prompt)
                combined_texts.append(prompt + " " + example["target_text"])
                if len(prompts) >= sample_count:
                    break

        elif kind == "synthetic":
            set_seed(seed)
            pairs = []
            if dataset_split == "train":
                pairs = [_synthetic_pair(13, i) for i in range(sample_count)]
            elif dataset_split in ("test_id", "test_ood", "test"):
                # Integer division: range() rejects the float that "/" produces.
                half_count = sample_count // 2
                if dataset_split in ("test_id", "test"):
                    # In-distribution: same modulus as the training split.
                    pairs.extend(_synthetic_pair(13, i) for i in range(half_count))
                if dataset_split in ("test_ood", "test"):
                    # Out-of-distribution: any modulus in 5..25 except 13.
                    ood_moduli = [k for k in range(5, 26) if k != 13]
                    for i in range(half_count):
                        K_ood = random.choice(ood_moduli)
                        pairs.append(_synthetic_pair(K_ood, i + 10000))
            elif dataset_split == "validation":
                # NOTE: the validation set has a fixed size; the caller's
                # sample_count is deliberately replaced so that the
                # subsampling step at the end does not shrink it.
                sample_count = 2000
                pairs = [_synthetic_pair(13, i) for i in range(sample_count)]
            # Any other split name yields no synthetic samples.
            prompts = [prompt for prompt, _ in pairs]
            combined_texts = [combined for _, combined in pairs]

        elif kind == 'ecqa':
            dataset = load_dataset("tasksource/ecqa", split=dataset_split)
            dataset = dataset.shuffle(seed=seed)
            for example in dataset:
                question = example.get("q_text", "")
                options = [example.get(f"q_op{n}", "") for n in range(1, 6)]
                answer_text = example.get("q_ans", "")

                # Options are lettered A..E in their original order.
                options_text = "\n".join(f"{chr(65 + j)}. {opt}" for j, opt in enumerate(options))
                prompt = f"Question: {question}\nOptions:\n{options_text}\nAnswer:"

                # Letter of the first option whose text matches the answer
                # (ignoring case and surrounding whitespace); "A" as fallback.
                correct_letter = next(
                    (
                        chr(65 + j)
                        for j, opt in enumerate(options)
                        if opt.strip().lower() == answer_text.strip().lower()
                    ),
                    "A",
                )

                prompts.append(prompt)
                combined_texts.append(prompt + " " + correct_letter)

                if len(prompts) >= sample_count:
                    break

        else:
            raise ValueError(
                f"Unknown dataset type: {dataset_type}. "
                "Supported types are 'clutrr', 'synthetic', and 'ecqa'."
            )

    # Only file input can exceed sample_count; draw a random subset in that case.
    if len(prompts) > sample_count:
        indices = random.sample(range(len(prompts)), sample_count)
        prompts = [prompts[i] for i in indices]
        combined_texts = [combined_texts[i] for i in indices]

    print(f"Loaded {len(prompts)} samples from {dataset_type} dataset ({dataset_split} split)")
    return prompts, combined_texts
